{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook imports allennlp.data.iterators and calls Trainer(iterator=...),\n",
    "# both of which exist only in the AllenNLP 0.x series, so pin the last 0.x\n",
    "# release instead of installing whatever is current.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install allennlp==0.9.0\n",
    "# NOTE(review): the cloned repo's default branch must also target this AllenNLP\n",
    "# version for `examples.sentiment.sst_classifier` to import - confirm, or check\n",
    "# out a matching commit after cloning.\n",
    "!git clone https://github.com/mhagiwara/realworldnlp.git\n",
    "%cd realworldnlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from typing import Dict\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from allennlp.common.file_utils import cached_path\n",
    "from allennlp.data.dataset_readers import DatasetReader\n",
    "from allennlp.data.fields import LabelField, TextField\n",
    "from allennlp.data.instance import Instance\n",
    "from allennlp.data.iterators import BucketIterator\n",
    "from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer\n",
    "from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer\n",
    "from allennlp.data.vocabulary import Vocabulary\n",
    "from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper\n",
    "from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder\n",
    "from allennlp.modules.token_embedders import Embedding\n",
    "from allennlp.training.trainer import Trainer\n",
    "from overrides import overrides\n",
    "\n",
    "from examples.sentiment.sst_classifier import LstmClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model sizes: EMBEDDING_DIM is the embedding width and the LSTM input size,\n",
    "# HIDDEN_DIM is the LSTM hidden size.\n",
    "EMBEDDING_DIM = 16\n",
    "HIDDEN_DIM = 16\n",
    "\n",
    "# Seed every RNG in play before any weights are created or batches shuffled,\n",
    "# so that Restart & Run All gives repeatable results.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TatoebaSentenceReader(DatasetReader):\n",
    "    \"\"\"Dataset reader for tab-separated ``language id`` / ``sentence`` files.\n",
    "\n",
    "    Each sentence is tokenized with ``CharacterTokenizer`` and yielded as an\n",
    "    ``Instance`` holding a ``tokens`` text field and, when a label is given,\n",
    "    a ``label`` field.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, token_indexers: Dict[str, TokenIndexer]=None, lazy=False):\n",
    "        \"\"\"Create the reader.\n",
    "\n",
    "        Args:\n",
    "            token_indexers: indexers for the ``tokens`` field; defaults to a\n",
    "                single ``SingleIdTokenIndexer`` under the key ``'tokens'``.\n",
    "            lazy: forwarded unchanged to ``DatasetReader.__init__``.\n",
    "        \"\"\"\n",
    "        super().__init__(lazy=lazy)\n",
    "        self.tokenizer = CharacterTokenizer()\n",
    "        self.token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}\n",
    "\n",
    "    @overrides\n",
    "    def text_to_instance(self, tokens, label=None):\n",
    "        \"\"\"Wrap already-tokenized text (and an optional label) in an ``Instance``.\n",
    "\n",
    "        Args:\n",
    "            tokens: tokens to store in the ``tokens`` text field.\n",
    "            label: class label; the ``label`` field is omitted when this is None.\n",
    "\n",
    "        Returns:\n",
    "            An ``Instance`` with a ``tokens`` field and, optionally, a ``label`` field.\n",
    "        \"\"\"\n",
    "        fields = {'tokens': TextField(tokens, self.token_indexers)}\n",
    "\n",
    "        # Test against None explicitly so that only a missing label is skipped.\n",
    "        if label is not None:\n",
    "            fields['label'] = LabelField(label)\n",
    "\n",
    "        return Instance(fields)\n",
    "\n",
    "    @overrides\n",
    "    def _read(self, file_path: str):\n",
    "        \"\"\"Yield one instance per non-blank line of ``file_path``.\n",
    "\n",
    "        Args:\n",
    "            file_path: path or URL of the data file; it is passed through\n",
    "                ``cached_path`` before being opened.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if a non-blank line has no tab separating the\n",
    "                language id from the sentence.\n",
    "        \"\"\"\n",
    "        file_path = cached_path(file_path)\n",
    "        # The corpus is multilingual, so decode explicitly instead of relying on\n",
    "        # the platform's default encoding (file is assumed to be UTF-8).\n",
    "        with open(file_path, \"r\", encoding=\"utf-8\") as text_file:\n",
    "            for line in text_file:\n",
    "                line = line.rstrip()\n",
    "                if not line:\n",
    "                    # Blank lines (e.g. a trailing newline at EOF) carry no example.\n",
    "                    continue\n",
    "\n",
    "                # Split on the first tab only so a tab inside the sentence\n",
    "                # does not break the two-way unpacking.\n",
    "                lang_id, sent = line.split('\\t', 1)\n",
    "\n",
    "                tokens = self.tokenizer.tokenize(sent)\n",
    "\n",
    "                yield self.text_to_instance(tokens, lang_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify(text: str, model: LstmClassifier):\n",
    "    \"\"\"Print the label that ``model`` predicts for ``text``.\n",
    "\n",
    "    The text is tokenized with ``CharacterTokenizer``, wrapped in an instance\n",
    "    with a single ``tokens`` field, and run through the model. The index of\n",
    "    the largest logit is looked up in the ``labels`` namespace of the model's\n",
    "    vocabulary. Nothing is returned; the result is only printed.\n",
    "\n",
    "    Args:\n",
    "        text: the sentence to classify.\n",
    "        model: the classifier to run.\n",
    "    \"\"\"\n",
    "    chars = CharacterTokenizer().tokenize(text)\n",
    "    text_field = TextField(chars, {'tokens': SingleIdTokenIndexer()})\n",
    "\n",
    "    output = model.forward_on_instance(Instance({'tokens': text_field}))\n",
    "    best_index = np.argmax(output['logits'])\n",
    "    predicted = model.vocab.get_token_from_index(best_index, 'labels')\n",
    "\n",
    "    print('text: {}, label: {}'.format(text, predicted))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reader with its defaults: a character tokenizer and a single-id indexer\n",
    "# registered under the key 'tokens'.\n",
    "reader = TatoebaSentenceReader()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training and dev splits; the reader passes each URL through cached_path()\n",
    "# before opening it.\n",
    "train_set = reader.read('https://s3.amazonaws.com/realworldnlpbook/data/tatoeba/sentences.top10langs.train.tsv')\n",
    "dev_set = reader.read('https://s3.amazonaws.com/realworldnlpbook/data/tatoeba/sentences.top10langs.dev.tsv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the vocabulary from the training instances only, with a minimum\n",
    "# count of 3 for the 'tokens' namespace.\n",
    "vocab = Vocabulary.from_instances(train_set,\n",
    "                                  min_count={'tokens': 3})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One EMBEDDING_DIM-sized vector per entry of the 'tokens' vocabulary namespace;\n",
    "# the embedder key \"tokens\" matches the indexer key used by the reader.\n",
    "token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),\n",
    "                            embedding_dim=EMBEDDING_DIM)\n",
    "word_embeddings = BasicTextFieldEmbedder({\"tokens\": token_embedding})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch-first LSTM (EMBEDDING_DIM in, HIDDEN_DIM hidden) wrapped as a\n",
    "# Seq2Vec encoder.\n",
    "encoder = PytorchSeq2VecWrapper(\n",
    "    torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LstmClassifier is imported from examples/sentiment/sst_classifier in the cloned repo.\n",
    "# NOTE(review): positive_label='eng' is presumably the class its metrics are\n",
    "# reported for - confirm against that module.\n",
    "model = LstmClassifier(word_embeddings, encoder, vocab, positive_label='eng')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of 32, sorted by the 'num_tokens' key of the 'tokens' field;\n",
    "# index_with() hands the iterator the vocabulary.\n",
    "iterator = BucketIterator(batch_size=32, sorting_keys=[(\"tokens\", \"num_tokens\")])\n",
    "iterator.index_with(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all model parameters, using the optimizer's default hyperparameters.\n",
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for 10 epochs with dev_set as the validation dataset.\n",
    "trainer = Trainer(model=model,\n",
    "                  optimizer=optimizer,\n",
    "                  iterator=iterator,\n",
    "                  train_dataset=train_set,\n",
    "                  validation_dataset=dev_set,\n",
    "                  num_epochs=10)\n",
    "\n",
    "# Last expression of the cell, so the value returned by train() is displayed.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the trained model on hand-picked sentences; classify() prints\n",
    "# each sentence with its predicted label.\n",
    "# NOTE(review): presumably one sentence per language in the training data - confirm.\n",
    "SAMPLE_SENTENCES = [\n",
    "    'Take your raincoat in case it rains.',\n",
    "    'Tu me recuerdas a mi padre.',\n",
    "    'Wie organisierst du das Essen am Mittag?',\n",
    "    \"Il est des cas où cette règle ne s'applique pas.\",\n",
    "    'Estou fazendo um passeio em um parque.',\n",
    "    'Ve, postmorgaŭ jam estas la limdato.',\n",
    "    'Credevo che sarebbe venuto.',\n",
    "    'Nem tudja, hogy én egy macska vagyok.',\n",
    "    'Nella ur nli qrib acemma deg tenwalt.',\n",
    "    'Kurşun kalemin yok, değil mi?',\n",
    "]\n",
    "\n",
    "for sentence in SAMPLE_SENTENCES:\n",
    "    classify(sentence, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
